package com.tiger.chat.provider;

import cn.hutool.core.util.StrUtil;
import com.theokanning.openai.OpenAiService;
import com.theokanning.openai.completion.CompletionChoice;
import com.theokanning.openai.completion.CompletionRequest;
import com.tiger.chat.config.ChatProperty;
import com.tiger.chat.config.WebSocketServer;
import com.tiger.chat.entity.OpenAi;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;

import java.util.Arrays;
import java.util.List;

/**
 * Spring component that forwards prompts to the OpenAI completion endpoint through
 * {@link OpenAiService} and hands itself to {@link WebSocketServer} once it is initialised.
 *
 * <p>Two fixed presets are provided: {@link #sendChatMsg(String)} uses {@code COMMON_CHAT} and
 * {@link #sendCodeMsg(String)} uses {@code CODE_CHAT}.
 */
@Slf4j
@Component
public class OpenAIProvider implements InitializingBean {

    // Presets. The meaning of each positional argument is defined by the OpenAi entity
    // constructor, which is not visible in this file.
    private static final OpenAi COMMON_CHAT = new OpenAi("COMMON_CHAT", "正常对话", "正常对话", "text-davinci-003", "%s", 0.9, 1.0, 1.0, 0.0, 0.6, "");
    private static final OpenAi CODE_CHAT = new OpenAi("CODE_CHAT", "代码相关", "代码相关", "code-davinci-002", "%s", 0.0, 1.0, 1.0, 0.5, 0.0, "");

    /** Configured timeouts below this value are replaced by {@link #DEFAULT_TIMEOUT}. */
    private static final int MIN_TIMEOUT = 1000;

    /** Timeout used when the configured value is below {@link #MIN_TIMEOUT}. */
    private static final int DEFAULT_TIMEOUT = 3000;

    /** Upper bound on generated tokens sent with every completion request. */
    private static final int MAX_TOKENS = 1000;

    /** Separator used to split the preset's stop string into individual stop sequences. */
    private static final String STOP_SEPARATOR = ",";

    /** Text appended after every choice (including the last one) in {@link #getAiResultStr}. */
    private static final String CHOICE_TERMINATOR = ";";

    private final OpenAiService service;

    /**
     * Creates the underlying OpenAI client from the configured token and timeout.
     *
     * <p>NOTE(review): the 1000/3000 thresholds read like milliseconds, but the
     * {@code OpenAiService(String, int)} constructor appears to interpret its argument as
     * seconds in the theokanning client — confirm the intended unit against the library version
     * in use before changing these values.
     *
     * @param chatProperty configuration supplying the API token and the request timeout
     */
    public OpenAIProvider(ChatProperty chatProperty) {
        int configuredTimeout = chatProperty.getTimeout();
        int timeout = configuredTimeout < MIN_TIMEOUT ? DEFAULT_TIMEOUT : configuredTimeout;
        this.service = new OpenAiService(chatProperty.getToken(), timeout);
    }

    /**
     * Sends {@code msg} using the general conversation preset.
     *
     * @param msg prompt text
     * @return the concatenated completion texts, see {@link #getAiResultStr(OpenAi, String)}
     */
    public String sendChatMsg(String msg) {
        return getAiResultStr(COMMON_CHAT, msg);
    }

    /**
     * Sends {@code msg} using the code preset.
     *
     * @param msg prompt text
     * @return the concatenated completion texts, see {@link #getAiResultStr(OpenAi, String)}
     */
    public String sendCodeMsg(String msg) {
        return getAiResultStr(CODE_CHAT, msg);
    }

    /**
     * Builds a completion request from the preset and prompt and executes it.
     *
     * <p>The prompt is sent as given; model, temperature, top-p and both penalties are taken
     * from {@code openAi}. A non-blank stop string is split on commas into stop sequences
     * (entries are not trimmed).
     *
     * @param openAi preset supplying the model and sampling parameters
     * @param prompt prompt text sent to the model
     * @return the choices exactly as returned by the client response
     */
    public List<CompletionChoice> getAiResult(OpenAi openAi, String prompt) {
        log.info("发送给OpenAi处理：{}，{}", openAi.getModel(), prompt);
        CompletionRequest.CompletionRequestBuilder builder = CompletionRequest.builder()
                .model(openAi.getModel())
                .prompt(prompt)
                .temperature(openAi.getTemperature())
                .maxTokens(MAX_TOKENS)
                .topP(openAi.getTopP())
                .frequencyPenalty(openAi.getFrequencyPenalty())
                .presencePenalty(openAi.getPresencePenalty());

        String stop = openAi.getStop();
        if (StrUtil.isNotBlank(stop)) {
            builder.stop(Arrays.asList(stop.split(STOP_SEPARATOR)));
        }
        return service.createCompletion(builder.build()).getChoices();
    }

    /**
     * Executes a completion and joins the text of every returned choice.
     *
     * <p>Each choice text is followed by {@code ";"}, so a non-empty result always ends with a
     * semicolon. An absent or empty choice list yields an empty string instead of failing.
     *
     * @param openAi preset supplying the model and sampling parameters
     * @param prompt prompt text sent to the model
     * @return the concatenated choice texts, or {@code ""} when no choices were returned
     */
    public String getAiResultStr(OpenAi openAi, String prompt) {
        List<CompletionChoice> choices = getAiResult(openAi, prompt);
        // Guard against a response without a choice list; iterating null would throw.
        if (choices == null || choices.isEmpty()) {
            return "";
        }
        StringBuilder joined = new StringBuilder();
        for (CompletionChoice choice : choices) {
            joined.append(choice.getText()).append(CHOICE_TERMINATOR);
        }
        return joined.toString();
    }

    /**
     * Registers this provider with {@link WebSocketServer} via its static {@code initAI} hook
     * after Spring has finished populating the bean.
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        WebSocketServer.initAI(this);
        log.info("AI服务初始化...");
    }
}
